《GNNExplainer: Generating Explanations for Graph Neural Networks》
图是强大的数据表示形式,而图神经网络(Graph Neural Network: GNN
)是处理图的最新技术。 GNN
能够递归地聚合图中邻域节点的信息,从而自然地捕获图结构和节点特征。
尽管GNN
效果很好,但是其可解释性较差。由于以下几个原因,GNN
的可解释性非常重要:
可以提高对 GNN
模型的信任程度。
在越来越多有关公平(fairness
)、隐私(privacy
)、以及其它安全挑战(safety challenge
)的关键决策应用(application
)中,提高模型的透明度(transparency
)。
允许从业人员理解模型特点,从而在实际环境中部署之前就能识别并纠正模型的错误。
尽管目前尚无用于解释 GNN
的方法。但是在更高层次,我们可以将包括 GNN
和 Non-GNN
的可解释性方法分为两个主要系列:
使用简单的代理模型(surrogate model
)局部逼近(locally approximate
)完整模型(full model
),然后探索这个代理模型以进行解释。
这可能是以模型无关(model-agnostic
)的方式来完成的,通常是通过线性模型(linear model
)或者规则集合(set of rules
)来学习预测的局部近似,可以充分代表预测结果。
仔细检查模型的相关特征(relevant feature
),并找到 high-level
特征的良好定性解释(good qualitative interpretation
),或者识别有影响力的输入样本(influential input instance
)。
如通过特征梯度(feature gradients
)、神经元反向传播中输入特征的贡献、反事实推理(counterfactual reasoning
)等。
上述两类方法专注于研究模型的固有解释,而后验(post-hoc
)的可解性方法将模型视为黑盒,然后对其进行探索从而获得相关信息。
但是所有这些方法无法融合关系信息,即图的结构信息。由于关系信息对于图上机器学习任务的成功至关重要,因此对 GNN
预测的任何解释都应该利用图提供的丰富关系信息、以及节点特征。论文 《GNNExplainer: Generating Explanations for Graph Neural Networks》
提出了 一种解释 GNN
预测的方法,称作 GNNEXPLAINER
。GNNEXPLAINER
接受训练好的GNN
及其预测,返回对于预测最有影响力的、输入图的一个小的子图(small subgraph
),以及节点特征的一个小的子集(small subset
)。
如下图所示,这里展示了一个节点分类任务,其中在社交网络上训练了 GNN
模型 GNN
GNNEXPLAINER
通过识别对预测 explanation
),如右图所示。
通过检查 GNN
预测
同样地,通过检查 GNN
预测
GNNEXPLAINER
方法和 GNN
模型无关,可以解释任何 GNN
在图的任何机器学习任务上的预测,包括节点分类、链接预测、图分类任务。它可以解释单个实例的预测(single-instance explanation
),也可以解释一组实例的预测(multi-instance explanation
)。
在单个实例预测的情况下,GNNEXPLAINER
解释了 GNN
对于特定样本的预测。
在多个实例预测的情况下,GNNEXPLAINER
提供了对于一组样本(如类别为“篮球”的所有节点)的预测的一致性解释(即这些预测的共同的解释)。
GNNEXPLAINER
将解释指定为训练 GNN
的整个输入图的某个子图,其中子图基于 GNN
的预测最大化互信息(mutual information
)。这是通过平均场变分近似(mean field variational approximation
),以及学习一个实值 graph mask
来实现的。这个 graph mask
选择了 GNN
输入图的最重要子图。同时,GNNEXPLAINER
还学习了一个 feature mask
,它可以掩盖不重要的节点特征。
论文在人工合成图、以及真实的图上评估GNNEXPLAINER
的效果。实验表明:GNNEXPLAINER
为 GNN
的预测提供了一致而简洁的解释。
虽然解释GNN
问题没有得到很好的研究,但相关的可解释性(interpretability
)和神经调试 (neural debugging
)的问题在机器学习中得到了大量的关注。在high-level
上,我们可以将那些 non-graph neural network
的可解释性方法分为两个主要方向:
第一个方向的方法制定了完整神经网络(full neural network
)的简单代理模型。这可以通过模型无关(model-agnostic
)的方式完成,通常是通过学习 prediction
周围的局部良好近似(例如通过线性模型或规则集合),代表预测的充分条件(sufficient condition
)。
第二个方向方法确定了计算的重要方面,例如,特征梯度(feature gradient
)、神经元的反向传播对输入特征的贡献、以及反事实推理。然而,这些方法产生的显著性映射(saliency map
)已被证明在某些情况下具有误导性,并且容易出现梯度饱和等问题。这些问题在离散输入(如图邻接矩阵)上更加严重,因为梯度值可能非常大而且位于一个非常小的区域(interval
)。正因为如此,这种方法不适合用来解释神经网络在图上的预测。
事后可解释性(post-hoc interpretability
)方法不是创建新的、固有可解释的模型,而是将模型视为黑箱,然后探测模型的相关信息。然而,还没有利用关系型结构(如graph
)方面的工作。解释图结构数据预测的方法的缺乏是有问题的,因为在很多情况下,图上的预测是由节点和它们之间的边的路径的复杂组合引起的。例如,在一些任务中,只有当图中存在另一条替代路径形成一个循环时,一条边才是重要的,而这两个特征,只有在一起考虑时,才能准确预测节点标签。因此,它们的联合贡献不能被建模为单个贡献的简单线性组合。
最后,最近的GNN
模型通过注意力机制增强了可解释性。然而,尽管学到的edge attention
值可以表明重要的图结构,但这些值对所有节点的预测都是一样的。因此,这与许多应用相矛盾,在这些应用中,一条边对于预测一个节点的标签是至关重要的,但对于另一个节点的标签则不是。此外,这些方法要么仅限于特定的GNN
架构,要么不能通过共同考虑图结构和节点特征信息来解释预测结果。
GNNEXPLAINER
提供了各种好处 ,包括可视化语义相关结构以进行解释的能力、以及提供洞察GNN
的错误的能力。
定义图
每个节点
不失一般性,我们考虑节点分类问题的可解释性。定义 GNN
模型 label
。
我们假设GNN
模型 GNN
模型
首先,模型计算每对节点pair
对之间传递的消息。节点pair
对
其中:MSG(.)
为消息函数;representation
;
然后,对于每个节点 GNN
聚合来自其邻域的所有消息:
其中:AGG(.)
为一个邻域聚合函数;
最后,对于每个节点 GNN
根据 representation
:
其中 UPDATE(.)
为节点状态更新函数。
最终节点 embedding
为 representation
:
对于采用 MSG,AGG,UPDATE
计算组成的任何 GNN
,我们的 GNNEXPLAINER
可以提供解释。
我们的洞察(insight
)是观察到:节点 computation graph
)是由 GNN
的neighborhood-based
聚合来定义,如下图所示。这个计算图完全决定了用于生成节点 GNN
如何生成节点 embedding
定义节点 binary
邻接矩阵 0
或1
; 也关联一个特征矩阵
GNN
模型
一旦GNN
模型学到这样的分布之后,对于节点 GNN
的类别预测结果为 A
所示。
正式地讲,GNNEXPLAINER
为预测 explanation
),记作
(small subgraph
),如图 A
所示。
small subset
)。F
表示通过 mask F
来遮盖,即:B
所示。
假设原始的节点特征集合为 mask F
遮盖之后的特征集合为:
它是原始特征集合的一个小的特征子集,且有:
下图中:
图 A
给出了一个 GNN
在节点
GNN
都会具有所有消息(包括不重要的消息)从而进行预测,这可能会稀释重要的消息。
GNNEXPLAINER
的目标是识别少量对于预测至关重要的重要特征和路径(绿色)。
图 B
表示 GNNEXPLAINER
通过学习节点特征mask
来确定
接下来我们详细描述 GNNEXPLAINER
。给定训练好的 GNN
模型 prediction
)(即单实例解释single-instance explanation
)、或者一组预测(即多实例解释multi-instance explanation
), GNNEXPLAINER
将通过识别对模型
在多实例解释中,GNNEXPLAINER
将每个实例的解释聚合在一起并自动抽取为一个原型(proto
)。这个原型代表每个实例解释的公共部分,即 proto
可以对所有这些实例进行解释。
给定一个节点 GNN
预测 mask
,这留待下一步讨论。
我们使用互信息(mutual information:MI
)来刻画子图的重要性,并将 GNNEXPLAINER
形式化为以下最优化问题:
其中:
GNN
对于节点 GNN
预测节点
注意:这里没有任何关于节点 GNN
预测得准不准,而是仅关心哪些因素和 GNN
预测结果相关。
其实是 ,即以原始图、原始特征矩阵来进行的预测所得到的熵。
GNN
预测结果不确定性程度。
MI
刻画了当节点
考虑
类似地,考虑
GNN
,节点
因此,对于预测 GNN
被限制在 uncertainty
)。在效果上,
理论上当 GNNEXPLAINER
旨在通过采取对预测提供最高互信息的
直接优化 GNNEXPLAINER
的目标函数很困难,因为 fractional adjacency matrix
),即 0~1.0
之间。此外我们施加约束
这种连续性松弛(continuous relaxation
)可以解释为 variational approximation
)。具体而言,我们将 random graph variable
),则目标函数变为:
我们假设目标函数是凸函数,则 Jensen
不等式给出以下的上界:
实际上由于神经网络的复杂性,凸性假设不成立。但是通过实验我们发现:优化带正则化的上述目标函数通常求得一个局部极小值,该局部极小值具有高质量的解释性。
为精确地估计 mean-field variational approximation
),并将 multivariate Bernoulli distribution
):
这允许我们估计对于平均场近似的期望从而获得
我们从实验观察到:尽管 GNN
是非凸的,但是这种近似(approximation
)结合一个可以提升离散型(discreteness
)的正则化器一起,结果可以收敛到良好的局部极小值。
可以通过使用邻接矩阵的计算图的掩码
其中:
mask
矩阵。
sigmoid
函数,它将 mask
映射到 0.0~1.0
之间。
GNNExplainer
的核心在于:用0.0 ~ 1.0
之间的mask
矩阵(待学习)来调整邻接矩阵,从而最小化预测的熵。但是,这种方法只关心哪个子图对预测结果最重要,并不关心哪个子图对ground-truth
最有帮助。可以通过标签类别和模型预测之间的交叉熵来修改上式中的条件熵,从而得到哪个子图对
ground-truth
最有帮助。
通过随机梯度下降来学习。
在某些应用(application
)中,我们不关心模型预测结果的
尽管有不同的动机和目标,在 Neural Relational Inference
中也发现了masking
方法。
最后,我们计算
为确定哪些节点特征对于预测 GNNEXPLAINER
针对 GNNEXPLAINER
考虑
我们通过一个 mask
来定义特征选择器:
其中 0
或 1
,当它为1
时表示保留对应特征,否则遮盖对应特征。因此 mask out
)的节点特征。
我们定义特征 mask
矩阵为:
则有:
现在我们在互信息目标函数中考虑节点特征,从而得到解释
该目标函数同时考虑了对预测
从直觉上看:
如果某个节点特征不重要,则 GNN
权重矩阵中的相应权重应该接近于零。mask
这类特征对于预测结果没有影响。
如果某个节点特征很重要,则 GNN
权重矩阵中相应权重应该较大。mask
这类特征会降低预测为
但是在某些情况下,这种方法会忽略对于预测很重要、但是特征取值接近于零的特征。为解决该问题,我们对所有特征子集边际化(marginalize
),并在训练过程中使用蒙特卡洛估计从
此外,我们使用 reparametrization
技巧将目标函的梯度反向传播到 mask
矩阵
具体而言,为了通过 reparametrize
其中:
上式等价于:
。因此 由两部分加权和得到:
:来自于每个维度边际分布采样得到的,权重为 ,代表噪音部分。这是为了解决特征取值接近于零但是又对于预测很重要的特征的问题。
:来自于子图节点的特征向量,权重为 ,代表真实信号部分。 这种特征可解释方法可以用于普通的神经网络模型。
为了在解释中加入更多属性,可以使用正则化项扩展 GNNEXPLAINER
的目标函数。可以包含很多正则化项从而产生具有所需属性的解释。
例如,我们使用逐元素的熵来鼓励结构mask
和节点特征mask
是离散的。
例如,我们可以将 mask
参数的所有元素之和作为正则化项,从而惩罚规模太大的mask
。
此外, GNNEXPLAINER
可以通过诸如拉格朗日乘子(Lagrange multiplier
)约束、或者额外的正则化项等技术来编码 domain-specific
约束。
最后需要重点注意的是:每个解释必须是一个有效的计算图。具体而言, GNN
的消息流向节点 GNN
做出预测
重要的是,GNNEXPLAINER
的解释一定是有效的计算图,因为它在整个计算图上优化结构 mask
。即使一条断开的边对于消息传递很重要,GNNEXPLAINER
也不会选择它作为解释,因为它不会影响 GNN
的预测结果。实际上,这意味着 small connected subgraph
)。
这是因为
GNNExplainer
会运行GNN
,如果计算图无效则运行GNN
的结果失败或者预测效果很差,因此也就不会作为可解释结果。
有时候我们需要回答诸如 “为什么 GNN
对于一组给定的节点预测都是类别 c
” 之类的问题。因此我们需要获得对于类别 c
的全局解释。
这里我们提出一个基于 GNNEXPLAINER
的解决方案,从而在类别 c
中的一组不同节点的各自单实例解释中,找到针对类别c
的通用的解释。这个问题与寻找每个解释图中最大公共子图密切相关,这是一个 NP-hard
问题。这里我们采用了解决该问题的神经网络方案,案称作基于对齐(alignment-based
)的 multi-instance GNNEXPLAINER
。
对于给定的类 c
,我们首先选择一个参考节点(reference node
) prototypical node
)。
可以通过计算类别 c
中所有节点的 embedding
均值,然后选择类别 c
中节点 embedding
和这个均值最近的节点作为参考节点。
也可以使用有关先验知识,选择和先验知识最匹配的节点作为类别 c
的参考节点。
给定类别 c
的参考节点 reference
解释图
利用微分池化(differentiable pooling
)的思想,我们使用一个松弛(relaxed
)的对齐矩阵(alignment matrix
)来找到解释图 reference
解释图 relaxed alignment matrix
)
其中:
1.0
。
上式第一项表示:经过对齐之后,
实际上对于两个大图
一旦得到类别 c
中所有节点对齐后的邻接矩阵,我们就可以使用中位数来生成一个原型(prototype
)。之所以使用中位数,是因为中位数可以有效对抗异常值。即:
其中 c
中第 explanation
的对齐后的邻接矩阵(即
原型 explanation
和类别原型进行比较,从而研究该特定节点。
在多个解释图的邻接矩阵对齐过程中,也可以使用现在的图库(graph library
)来寻找这些解释图的最大公共子图,从而替换掉神经网络部分。
在多实例解释中,解释器(explainer
)不仅必须突出与单个预测的局部相关信息,还需要强调不同实例之间更高level
的相关性。
这些实例之间可以通过任意方式产生关联,但是最常见的还是类成员(class-membership
)关联。假设类的不同样本之间存在共同特征,那么解释器需要捕获这种共同的特征。例如,通常发现诱变化合物 (mutagenic compounds
)具有某些特定属性的功能团,如 NO2
。
如下图所示,经验丰富的专家可能已经注意到这些功能团的存在。当 GNNEXPLAINER
生成原型(prototype
)时,可以进一步加强这方面的证据。下图来自于 MUTAG
数据集的诱变化合物。
机器学习任务的扩展:除了解释节点分类之外,GNNEXPLAINER
还可以解释链接预测和图分类,无需更改其优化算法。
在预测链接 GNNEXPLAINER
为链接的两个端点学习两个mask
在图分类时,目标函数中的邻接矩阵是图中所有节点邻接矩阵的并集( union
)。
注意:图分类任务和节点分类任务不同。由于图分类任务存在节点 embedding
的聚合,因此解释
模型扩展: GNNEXPLAINER
能够处理所有基于消息传递的GNN
,包括:Graph Convolutional Networks:GCN
、Gated Graph Sequence Neural Networks:GGS-NNs
、Jumping Knowledge Networks:JK-Net
、Attention Networks-GAT
、Graph Networks:GN
、具有各种聚合方案的 GNN
、Line-Graph NNs
、position-aware GNN
、以及很多其它 GNN
架构。
GNNEXPLAINER
优化中的参数规模取决于节点 GNNEXPLAINER
学习的。
但是,由于单个节点的计算图通常较小,因此即使完整的输入图很大 GNNEXPLAINER
仍然可以有效地生成解释。
数据集:
人工合成数据集:我们人工构建了四种节点分类数据集,如下表所示。
BA-SHAPES
数据集:我们从 300
个节点的 Barabasi-Albert:BA
基础图、以及一组80
个五节点的房屋(house
)结构的主题(motif
)开始,这些 motif
被随机添加到基础图的随机选择的节点上。进一步地我们添加
根据节点的结构角色,节点为以下四种类型之一:house
顶部节点、house
中间节点、house
底部节点、非house
节点。
BA-COMMUNITY
数据集:是两个 BA-SHAPES
图的并集。节点具有正态分布的特征向量,并且根据其结构角色、社区成员(两种社区)分配为8
种类别之一。
TREE-CYCLES
:从 8-level
平衡二叉树为基础图、一组 80
个 六节点的环状 motif
开始,这些 motif
随机添加到基础图的随机选择的节点上。
TREE-GRID
:和 TREE-CYCLES
相同,除了使用 3x3
的网格 motif
代替六节点的环 motif
之外。
真实数据集:我们考虑两个图分类数据集。
MUTAG
:包含 4337
个分子图的数据集,根据分子对于革兰氏阴性菌伤寒沙门氏菌(Gram-negative bacterium S.typhimurium
)的诱变作用(mutagenic effect
)进行标记。
REDDIT-BINARY
:包含 2000
个图的数据集,每个图代表Reddit
上讨论的话题 (thread
)。在每个图中,节点代表话题下参与讨论的用户,边代表一个用户对另一个用户的评论进行了回复。
图根据话题中用户交互类型进行标记:r/IAmA, r/AskReddit
包含 Question-Answer
交互, r/TrollXChromosomes and r/atheism
包含Online-Discussion
交互。
Baseline
方法:很多可解释性方法无法直接应用于图,尽管如此我们考虑了以下baseline
方法,这些方法可以为 GNN
的预测提供解释。
GRAD
:基于梯度的方法。我们计算损失函数对于邻接矩阵的梯度、损失函数对于节点特征的梯度,这类似于显著性映射方法 (saliency map approach
)。
ATT
:基于graph attention GNN:GAT
的方法。它学习计算图中边的注意力权重,并将其视为边的重要性。
尽管 ATT
考虑了图结构,但是它并未考虑节点特征的解释,而且仅能解释 GAT
模型。
此外,由于环(cycle
)的存在(如下图所示),节点的 1hop
邻居也是它的 2-hop
邻居。因此使用哪个注意力权重(1hop vs 2hop
)也不是很清楚。通常我们将这些 hop
的注意力权重取均值。
实验配置:对于每个数据集,我们首先为这个数据集训练一个 GNN
,然后使用 GARD
和 GNNEXPLAINER
来对 GNN
的预测做出解释。
注意,ATT baseline
需要使用 GAT
之类的图注意力架构,因此我们在同一个数据集上单独训练了一个 GAT
模型,并使用学到的边注意力权重进行解释。
我们对所有的节点分类任务、图分类任务中调整权重正则化参数。这些超参数在所有实验中使用。
子图大小正则化超参数为 0.005
,该正则化倾向于得到尽可能小的子图。
拉普拉斯正则化参数为 0.5
。
特征数量正则化参数为 0.1
,该正则化倾向于得到尽可能少的unmasked
特征。
我们使用 Adam
优化器训练 GNN
和 解释方法 (explaination methods
)。
所有 GNN
模型都训练 1000
个 epoch
,学习率为 0.001
, 从而对节点分类数据集达到至少 85%
的准确率、对于图分类数据集达到至少95%
的准确率。
对于所有数据集,train/valid/test
拆分比例为 80%:10%:10%
。
GNNEXPLAINER
使用相同的优化器和学习率,并训练 100 ~300
个 epoch
。
因为 GNNEXPLAINER
仅需要在少于 100
个节点的局部计算图上进行训练,因此训练 epoch
要更少一些。
为了抽取解释子图 GRAD
的梯度、ATT
的注意力权重、GNNEXPLAINER
的 masked
邻接矩阵)。然后我们使用一个阈值来删除权重较低的边,从而得到
对于所有方法,我们执行线性搜索从而找到临界阈值,使得
所有数据集的 ground truth explanation
是连接的子图。
对于节点分类,我们将不同方法得到的 GNNEXPLAINER
方法来讲,
对于图分类,我们抽取
超参数
对于人工合成数据集,我们将 ground truth
的大小。
对于真实世界数据集,我们设置
定量分析:对于人工合成数据集,我们已有 ground-truth
解释,然后使用这些ground-truth
来评估所有方法解释的准确性。具体而言,我们将解释问题形式化为二元分类任务,其中真实解释中的边视为label
,而将可解释性方法给出的重要性权重视为预测得分。一种更好的可解释性方法对于真实解释的边的预测得分较高,从而获得更好的解释准确率。
下表给出了人工合成数据集节点分类评估结果。实验结果表明:GNNEXPLAINER
的平均效果相比其它方法高出 17.1%
。
定性分析:
在没有节点特征的 topology-based
预测任务中(如 BA-SHAPES、TREE-CYCLES
),GNNEXPLAINER
正确地识别解释节点标签的motif
。
如下图所示,A-B
给出了四个人工合成数据集上节点分类任务的单实例解释子图,每种方法都为红色节点的预测提供解释(绿色表示重要的节点,橙色表示不重要的节点)。可以看到 GNNEXPLAINER
能识别到 house, cycle, trid
等 motif
,而 baseline
方法无法识别。
我们研究图分类任务的解释。
在 MUTAG
实例中,颜色表示节点特征,这代表原子类型(氢H
、碳C
等)。GNNEXPLAINER
可以正确的识别对于图类别比较重要的碳环、以及化学基团 NH2
和 NO2
,它们确实已知是诱变的 (mutagenic
)官能团。
在 REDDIT-BINARY
示例中,我们看到Question-Answe
图(B
的第二行)具有2~3
个同时连接到很多低 degree
节点的高 degree
节点。这是讲得通的,因为在 Reddit
的问答模式的话题中,通常具有 2~3
位专家都回答了许多不同的问题。
相反,在 REDDIT-BINARY
的讨论模式(discussion pattern
)图(A
的第二行),通常表现出树状模式。
GRAD,ATT
方法给出了错误的或者不完整的解释。例如两种baseline
都错过了 MUTAG
数据集中的碳环。
此外,尽管 ATT
可以将边注意力权重视为消息传递的重要性得分,但是权重在输入图中的所有节点之间共享,因此 ATT
无法提供高质量的单实例解释。
解释(explanations
)的基本要求是它们必须是可解释的(interpretable
),即,提供对输入节点和预测之间关系的定性理解。下图显式了一个实验结果,其中给出不同方法预测的解释的特征。特征重要性通过热力度可视化。
可以看到:GNNEXPLAINER
确实识别出了重要的特征;而 gradient-based
无法识别,它为无关特征提供了较高的重要性得分。
ground-truth
特征从何而来?作者并未讲清楚。